# Calculating emission rates for each pollutant using OGE data
library(pacman)
p_load(
  here, data.table, fst, ggplot2, readxl, magrittr, 
  fixest, broom, dplyr, purrr, stringr, janitor, 
  lubridate, collapse
)

# Calculates plant-level emission and generation totals for one NERC region,
# using modeled plants only (plants with a file in the plant-gen-dt folder).
#
# Arguments:
#   nerc_adj_in: region label, compared against the `nerc_adj` column of
#     `oge_plant_info_dt`
#   oge_plant_info_dt: data.table with (at least) the columns `nerc_adj` and
#     `plant_id_eia`
#   time_bounds: datetime vector; hours outside
#     [time_bounds[1], time_bounds[2]] are dropped
#
# Returns: data.table keyed by `plant_id_eia` with total CO2e, NOx and SO2
#   (pounds summed and divided by 2000), total net generation (mwh), and
#   total fuel consumed for electricity (mmbtu).
calc_avg_emissions_rates_nerc = function(
  nerc_adj_in, 
  oge_plant_info_dt,
  time_bounds =  c(ymd_hms('2019-01-01 10:00:00'), ymd_hms('2019-12-31 23:00:00'))
){
  # Loading data --------------------------------------------
  print(paste0(nerc_adj_in, ': Loading data'))
  # Getting file paths for all plants in region
  plant_gen_fp = list.files(
    here('Data/electricity-generation/plant-model-fit/plant-gen-dt'),
    full.names = TRUE
  )
  # The plant id is the run of digits between the last '-' and '.fst'
  nerc_plant_gen_fp = 
    plant_gen_fp[
      str_extract(plant_gen_fp, '(?<=-)\\d*(?=\\.fst)') |> as.integer()
      %in% oge_plant_info_dt[nerc_adj == nerc_adj_in]$plant_id_eia
    ]
  # Hourly emissions data
  elec_gen_dt =
    map_dfr(
      nerc_plant_gen_fp, 
      read.fst, 
      as.data.table = TRUE
    )[datetime_utc %between% c(time_bounds[1], time_bounds[2]) & 
      # Dropping hours that were above nameplate capacity
      # (i.e., keeping hours where raw and adjusted generation agree)
      net_generation_mwh_raw == net_generation_mwh
    ] |> 
    setkey(plant_id_eia, datetime_utc)
  # Calculating annual average emissions ------------------
  print(paste0(nerc_adj_in, ': calculating rates'))
  emissions_dt_raw = 
    elec_gen_dt[,.(
      tot_co2e_tons_for_electricity = sum(co2e_mass_lb_for_electricity, na.rm = TRUE)/2000,
      tot_nox_tons_for_electricity = sum(nox_mass_lb_for_electricity, na.rm = TRUE)/2000,
      tot_so2_tons_for_electricity = sum(so2_mass_lb_for_electricity, na.rm = TRUE)/2000,
      tot_net_generation_mwh = sum(net_generation_mwh, na.rm = TRUE),
      tot_fuel_consumed_for_electricity_mmbtu = 
        sum(fuel_consumed_for_electricity_mmbtu, na.rm = TRUE)), 
      keyby = .(plant_id_eia)
    ]
  # Printing the status message before returning: it used to sit after 
  # return() and was never reached
  print(paste0(nerc_adj_in, ': done.'))
  return(emissions_dt_raw)
}


# Calculates average emissions rates (tons per mwh) of CO2e, NOx, SO2 and
# PM2.5 for each plant and writes them to csv.
#
# Steps:
#   1. Plant-level totals come from calc_avg_emissions_rates_nerc(), run once
#      per region in `nerc_adj_in`.
#   2. PM2.5 rates come from the eGRID unit-level file, aggregated to plants
#      using each unit's share of heat input in the raw CEMS data.
#   3. Missing PM2.5 rates are filled with the region-by-fuel median, then
#      the fuel median. All rates are capped at the region-by-fuel 95th
#      percentile.
#
# Arguments:
#   nerc_adj_in: character vector of NERC regions to process
#   time_bounds: datetime vector; its years select the plant info and CEMS
#     files, and the same window is passed on to the hourly data filter
#
# Side effects: writes emissions-rate-avg-oge.csv and
#   emissions-rate-avg-fc-oge.csv in Data/electricity-generation/output.
#   Stops if more than 10% of mwh lack a PM2.5 rate, or if any rate is 
#   still NA after imputation.
calc_avg_emissions_rates = function(
  nerc_adj_in = c('CAL','MRO','NPCC','RFC','SERC','TRE','WECC'),
  time_bounds =  c(ymd_hms('2019-01-01 10:00:00'), ymd_hms('2019-12-31 23:00:00'))
){
  # Plant information table: one row per plant, taken from the earliest 
  # year among the years of `time_bounds`
  # NOTE(review): unique(year(time_bounds)) only contains the years of the 
  # elements of time_bounds; a window spanning 3+ years would skip the years
  # in between -- confirm this is intended
  oge_plant_info_dt =
    read.fst(
      path = here("Data/electricity-generation/plant-info-dt-oge.fst"),
      as.data.table = TRUE
    )[year %in% unique(year(time_bounds))
    ][,.SD[which.min(year)], by= plant_id_eia]
  # Total heat input by unit in raw CEMS data (used as weights below)
  unit_weight_dt = 
    map_dfr(
      unique(year(time_bounds)),
      \(yr){
        read.fst(
          here(paste0("Data/electricity-generation/cems/raw_cems_",yr,".fst")),
          as.data.table = TRUE
        )[,year := yr][,
          .(tot_heat_input_mm_btu = sum(heat_input_mm_btu, na.rm = TRUE)), 
          keyby = .(orispl_code, unitid, year)
        ]
      }
    )[,
      .(tot_heat_input_mm_btu = sum(tot_heat_input_mm_btu, na.rm = TRUE)), 
      keyby = .(orispl_code, unitid)
    ]
  invisible(gc())
  # Reading in egrid PM2.5 emissions rates
  egrid_unit_pm_dt = 
    read_xlsx(
      path = here('Data/electricity-generation/egrid/egrid_draft_pm2.5_emissions_7-20-20.xlsx'),
      sheet = '2018 PM Unit-level Data',
      skip = 1
    ) |> 
    data.table() %>% 
    # Two plants have duplicates because of multiple prime mover types
    # Taking the max of the two
    .[PRMVR != 'HY',
      # PM25RT is in lbs/mmbtu, converting to tons
      .(pm_tons_per_mmbtu = max(PM25RT/2000, na.rm = TRUE)),
      keyby = .(orispl_code = ORISPL, unitid = UNITID)
      # max() over all-missing values gives -Inf, dropping those units
    ] %>% .[!is.nan(pm_tons_per_mmbtu) & !is.infinite(pm_tons_per_mmbtu)] 
  # Converting unit PM rates to plant PM rates 
  egrid_plant_pm_dt = 
    merge(
      egrid_unit_pm_dt,
      unit_weight_dt,
      by = c('orispl_code','unitid'),
      all.x = TRUE
    ) %>%
    # Weight = unit's share of the plant's heat input (0 if unit not in CEMS)
    .[,':='(
      n = .N,
      wt = ifelse(is.na(tot_heat_input_mm_btu),0,tot_heat_input_mm_btu)/
        sum(tot_heat_input_mm_btu, na.rm =TRUE)), 
      by = .(orispl_code)
    ] %>% 
    # Plants with zero total heat input (0/0 = NaN) get equal unit weights
    .[,wt := ifelse(is.nan(wt), 1/n, wt)] %>% 
    .[,.(
      pm_tons_per_mmbtu = weighted.mean(pm_tons_per_mmbtu, w = wt, na.rm = TRUE)),
      keyby = .(plant_id_eia = orispl_code)
    ]
  # Running for each nerc region
  # Passing time_bounds along so the hourly data uses the same window as 
  # the plant info and CEMS data above (previously the default was always used)
  emissions_dt_raw = 
    map_dfr(
      nerc_adj_in, 
      calc_avg_emissions_rates_nerc,
      oge_plant_info_dt = oge_plant_info_dt,
      time_bounds = time_bounds
    )
  # Calculating emissions rates for each pollutant
  # Plants with non-positive total generation get a rate of zero
  emissions_dt = 
    join(
      emissions_dt_raw,
      egrid_plant_pm_dt,
      on = c('plant_id_eia'),
      how = 'left'
    )[,.(
      plant_id_eia,  
      co2e_tons_per_mwh = ifelse(
        tot_net_generation_mwh <= 0, 0,
        tot_co2e_tons_for_electricity/tot_net_generation_mwh
      ), 
      nox_tons_per_mwh = ifelse(
        tot_net_generation_mwh <= 0, 0,
        tot_nox_tons_for_electricity/tot_net_generation_mwh
      ), 
      so2_tons_per_mwh = ifelse(
        tot_net_generation_mwh <= 0, 0,
        tot_so2_tons_for_electricity/tot_net_generation_mwh
      ), 
      pm25_tons_per_mwh = ifelse(
        tot_net_generation_mwh <= 0, 0,
        pm_tons_per_mmbtu*tot_fuel_consumed_for_electricity_mmbtu/
          tot_net_generation_mwh
      ),
      tot_net_generation_mwh
    )]
  # Checking how many observations are missing (generation-weighted share)
  missing_pm25_rate = emissions_dt[, .(
    fmean(is.na(pm25_tons_per_mwh), w = tot_net_generation_mwh)
  )]$V1
  print(paste0(
    round(100*missing_pm25_rate,1),
    '% of mwh missing pm2.5 emissions. Expecting 8-10% (ish).'
  ))
  if(missing_pm25_rate > 0.1) stop('too many missing pm2.5')
  # Calculating 95th percentile and median emissions rates by nerc and fuel type 
  # One plant is wind in 2020 but natural gas in 2019. No other shaped wind 
  # plants so switching it to natural gas
  nerc_median_emissions_dt = 
    join(
      emissions_dt, 
      oge_plant_info_dt,
      on = 'plant_id_eia'
    )[,fuel_category := ifelse(plant_id_eia == 62736, 'natural_gas', fuel_category)
    ][,.(
      pm25_tons_per_mwh_p50 = fnth(pm25_tons_per_mwh, n = 0.50),
      pm25_tons_per_mwh_p95 = fnth(pm25_tons_per_mwh, n = 0.95),
      co2e_tons_per_mwh_p50 = fnth(co2e_tons_per_mwh, n = 0.50),
      co2e_tons_per_mwh_p95 = fnth(co2e_tons_per_mwh, n = 0.95),
      so2_tons_per_mwh_p50 = fnth(so2_tons_per_mwh, n = 0.50),
      so2_tons_per_mwh_p95 = fnth(so2_tons_per_mwh, n = 0.95),
      nox_tons_per_mwh_p50 = fnth(nox_tons_per_mwh, n = 0.50),
      nox_tons_per_mwh_p95 = fnth(nox_tons_per_mwh, n = 0.95)),
      by = .(nerc_adj, fuel_category)
    ]
  # Also want median emissions rates by fuel category 
  # (fallback when the region-by-fuel median is itself missing)
  median_emissions_dt = 
    join(
      emissions_dt, 
      oge_plant_info_dt,
      on = 'plant_id_eia'
    )[,.(
      pm25_tons_per_mwh_p50_fc = fnth(pm25_tons_per_mwh, n = 0.50),
      pm25_tons_per_mwh_p95_fc = fnth(pm25_tons_per_mwh, n = 0.95),
      co2e_tons_per_mwh_p50_fc = fnth(co2e_tons_per_mwh, n = 0.50),
      co2e_tons_per_mwh_p95_fc = fnth(co2e_tons_per_mwh, n = 0.95),
      so2_tons_per_mwh_p50_fc = fnth(so2_tons_per_mwh, n = 0.50),
      so2_tons_per_mwh_p95_fc = fnth(so2_tons_per_mwh, n = 0.95),
      nox_tons_per_mwh_p50_fc = fnth(nox_tons_per_mwh, n = 0.50),
      nox_tons_per_mwh_p95_fc = fnth(nox_tons_per_mwh, n = 0.95)
      ),
      keyby = .(fuel_category)
    ]
  # Setting missing values to median and censoring high values
  emissions_dt2 = 
    join(
      emissions_dt, 
      oge_plant_info_dt,
      on = 'plant_id_eia'
    ) |>
    join(
      nerc_median_emissions_dt,
      on = c('nerc_adj','fuel_category'),
      how = 'left'
    ) |>
    join(
      median_emissions_dt,
      on = 'fuel_category',
      how = 'left'
    ) %>%
    .[,.(
      plant_id_eia,
      # PM2.5: region-by-fuel median if missing, then fuel median, 
      # otherwise the plant's own rate capped at the 95th percentile
      pm25_tons_per_mwh = fcase(
        is.na(pm25_tons_per_mwh) & !is.na(pm25_tons_per_mwh_p50), 
          pm25_tons_per_mwh_p50, 
        is.na(pm25_tons_per_mwh) & !is.na(pm25_tons_per_mwh_p50_fc), 
          pm25_tons_per_mwh_p50_fc,
        pm25_tons_per_mwh <= pm25_tons_per_mwh_p95,
          pm25_tons_per_mwh,
        pm25_tons_per_mwh > pm25_tons_per_mwh_p95, 
          pm25_tons_per_mwh_p95
      ),
      co2e_tons_per_mwh = ifelse(
        co2e_tons_per_mwh > co2e_tons_per_mwh_p95, co2e_tons_per_mwh_p95,
        co2e_tons_per_mwh
      ),
      so2_tons_per_mwh = ifelse(
        so2_tons_per_mwh > so2_tons_per_mwh_p95, so2_tons_per_mwh_p95,
        so2_tons_per_mwh
      ),
      nox_tons_per_mwh = ifelse(
        nox_tons_per_mwh > nox_tons_per_mwh_p95, nox_tons_per_mwh_p95,
        nox_tons_per_mwh
      ),
      tot_net_generation_mwh
    )]
  # Checking if there are any NA's
  if(nrow(na.omit(emissions_dt2)) < nrow(emissions_dt2)) {
    stop('missing avg emissions rates')
  }
  # Saving the results 
  fwrite(
    emissions_dt2,
    file = here('Data/electricity-generation/output/emissions-rate-avg-oge.csv')
  )
  fwrite(
    median_emissions_dt, 
    file = here('Data/electricity-generation/output/emissions-rate-avg-fc-oge.csv')
  )
}



# Helper: collects the tidy coefficient table for one pollutant.
# If the split model has two coefficients and implies a negative rate either
# below the split (base coefficient) or above it (base + interaction), the
# model is re-estimated without the split term. If the model has neither one
# nor two coefficients, a single row of NA estimates is returned.
.collect_pollutant_rates = function(
  pollutant_in, pollutant_mod, plant_dt, plant_id_eia_in, split_in, split_value
){
  n_coef = length(coef(pollutant_mod))
  # Nothing usable: return placeholder row
  if(!(n_coef %in% c(1, 2))){
    return(tibble(
      term = NA,
      estimate = NA, 
      std.error = NA,
      statistic = NA, 
      p.value = NA,
      pollutant = pollutant_in,
      plant_id_eia = plant_id_eia_in,
      split = split_in,
      net_generation_mwh_split = split_value
    ))
  }
  mod_out = pollutant_mod
  if(n_coef == 2){
    rate_low = coef(pollutant_mod, keep = 'net_generation_mwh')
    rate_change = coef(pollutant_mod, keep = 'net_gen_over_split')
    if((rate_low + rate_change < 0) | rate_low < 0){
      # Fitting model without median indicator
      refit_fml = as.formula(paste0(
        pollutant_in, '_mass_lb_for_electricity ~ -1 + net_generation_mwh'
      ))
      mod_out = feols(
        data = plant_dt, 
        fml = refit_fml
      )
    }
  }
  # Collecting the results 
  mod_out |> 
    tidy() |> 
    mutate(
      pollutant = pollutant_in,
      plant_id_eia = plant_id_eia_in,
      split = split_in,
      net_generation_mwh_split = split_value
    )
}

# Fits, for a single plant, a no-intercept model of hourly emissions on net 
# generation whose slope is allowed to differ above/below a quantile of 
# generation (the median by default), separately for NOx, SO2 and CO2e.
#
# Arguments:
#   plant_id_eia_in: id of the plant to fit
#   elec_gen_dt: hourly data.table, subset via a keyed join on the plant id 
#     (the caller in this file keys it by plant_id_eia, datetime_utc)
#   split: quantile of positive net generation at which the slope may change
#   print: if TRUE, prints the plant id before fitting
#
# Returns: data.table of tidy coefficient rows (term, estimate, ...) stacked
#   over the three pollutants, with columns identifying the pollutant, plant,
#   split quantile and the generation level at the split.
emissions_rates_median = function(plant_id_eia_in, elec_gen_dt, split = 0.5, print = FALSE){
  if(isTRUE(print)) {print(plant_id_eia_in)}
  # Filtering to single plant, removing hours with no electricity but positive emissions
  plant_dt = elec_gen_dt[.(plant_id_eia_in)][
    !(net_generation_mwh == 0 & (
      nox_mass_lb_for_electricity > 0
      |so2_mass_lb_for_electricity > 0
      |co2e_mass_lb_for_electricity > 0
    )) 
  ]
  # Finding quantile of generation (among hours with positive generation)
  plant_quants = quantile(
    plant_dt[net_generation_mwh>0]$net_generation_mwh, 
    probs = split
  )
  plant_min = min(plant_dt$net_generation_mwh)
  # Adding indicators 
  # An hour is 'over the split' only if it is also strictly above the plant's 
  # minimum generation
  plant_dt[, ':='(
    net_gen_over_split = net_generation_mwh >= plant_quants[1] & net_generation_mwh > plant_min,
    net_gen_msplit = net_generation_mwh - plant_quants[1]
  )]
  # Checking for some variance net generation
  # (without it the split term cannot be estimated)
  if(var(plant_dt$net_gen_over_split) > 0){
    fml_in = c(
      nox_mass_lb_for_electricity, 
      so2_mass_lb_for_electricity, 
      co2e_mass_lb_for_electricity
      ) ~ -1 + net_generation_mwh + i(net_gen_over_split, net_gen_msplit, ref = TRUE)
  }else{
    fml_in = c(
      nox_mass_lb_for_electricity, 
      so2_mass_lb_for_electricity, 
      co2e_mass_lb_for_electricity
      ) ~ -1 + net_generation_mwh
  }
  # Regression with split (one estimation per pollutant)
  median_mod = feols(
    data = plant_dt,
    fml = fml_in
  )
  # Same post-processing for each pollutant: re-estimating if marginal rate 
  # is negative, then collecting the tidy results
  median_out_nox = .collect_pollutant_rates(
    pollutant_in = 'nox',
    pollutant_mod = median_mod[lhs = 'nox'][[1]],
    plant_dt = plant_dt,
    plant_id_eia_in = plant_id_eia_in,
    split_in = split,
    split_value = plant_quants[1]
  )
  median_out_so2 = .collect_pollutant_rates(
    pollutant_in = 'so2',
    pollutant_mod = median_mod[lhs = 'so2'][[1]],
    plant_dt = plant_dt,
    plant_id_eia_in = plant_id_eia_in,
    split_in = split,
    split_value = plant_quants[1]
  )
  median_out_co2e = .collect_pollutant_rates(
    pollutant_in = 'co2e',
    pollutant_mod = median_mod[lhs = 'co2e'][[1]],
    plant_dt = plant_dt,
    plant_id_eia_in = plant_id_eia_in,
    split_in = split,
    split_value = plant_quants[1]
  )
  # Returning the results 
  return(
    rbind(median_out_nox, median_out_so2, median_out_co2e) %>% as.data.table() 
  )
}  

# Estimates the above/below-split emissions rate models for every plant in 
# one NERC region that also has a production model.
#
# Arguments:
#   nerc_adj_in: region label; its lower-case form selects the hourly file
#   time_bounds: datetime vector; hours outside 
#     [time_bounds[1], time_bounds[2]] are dropped
#
# Returns: the row-bound output of emissions_rates_median() over all plants run.
estimate_marginal_emissions_rate_nerc = function(
  nerc_adj_in,
  time_bounds =  c(ymd_hms('2019-01-01 10:00:00'), ymd_hms('2019-12-31 23:00:00'))
){
  print(paste(nerc_adj_in,'loading data'))
  # Hourly emissions data for the region
  region_fp = here(
    "Data/electricity-generation",
    paste0('elec-gen-dt/elec-gen-dt-',tolower(nerc_adj_in), ".fst")
  )
  region_dt = read.fst(path = region_fp, as.data.table = TRUE)
  # Keeping hours inside the window and not above nameplate capacity
  utc_window = c(time_bounds[1],time_bounds[2])
  elec_gen_dt = region_dt[
    datetime_utc %between% utc_window & above_namepcap == FALSE
  ]
  setkey(elec_gen_dt, plant_id_eia, datetime_utc)
  # Plant ids with a production model, parsed from the coefficient file names
  coef_files = list.files(here("Data/electricity-generation/plant-model-fit/coefs"))
  modeled_ids = as.integer(
    str_remove(str_remove(coef_files, "plant-mod-coefs-"), "\\.fst")
  )
  # Only running plants that appear in both the hourly data and the models
  region_ids = unique(elec_gen_dt$plant_id_eia)
  plants_to_run = region_ids[region_ids %in% modeled_ids]
  # Fitting the model plant by plant
  print(paste(nerc_adj_in,'fiting models'))
  emissions_rate_dt = map_dfr(
    plants_to_run,
    \(plant_id){
      emissions_rates_median(plant_id, elec_gen_dt = elec_gen_dt, print = TRUE)
    }
  )
  print(paste(nerc_adj_in,'done'))
  return(emissions_rate_dt)
}


# Estimates marginal emissions rates for all plants in the given NERC regions
# and saves both the raw coefficient table and a cleaned long table with one
# 'emissions_rate_low' and one 'emissions_rate_high' row per plant-pollutant.
#
# Arguments:
#   nerc_adj_in: character vector of NERC regions to process
#
# Side effects: writes emissions-rate-dt-raw.fst and emissions-rate-dt.fst 
#   (plant-model-fit folder) and emissions-rate-median-oge.csv (output folder).
estimate_marginal_emissions_rate = function(
  nerc_adj_in = c('CAL','MRO','NPCC','RFC','SERC','TRE','WECC')
){
  emissions_rate_dt = map_dfr(
    nerc_adj_in, 
    estimate_marginal_emissions_rate_nerc
  )
  # Name fixest gives to the slope-change term of the split model
  split_term = 'net_gen_over_split::TRUE:net_gen_msplit'
  # One row per plant-pollutant, one column per coefficient. 
  # Coefficients a plant does not have are filled with 0
  wide_rate_dt = dcast(
    emissions_rate_dt[!is.na(term)],
    plant_id_eia + pollutant + net_generation_mwh_split ~ term, 
    value.var = 'estimate', 
    fill= 0
  )
  # If no plant kept the split term the column does not exist at all; 
  # adding it as 0 (same meaning as fill = 0) instead of failing below
  if(!(split_term %in% names(wide_rate_dt))){
    wide_rate_dt[, (split_term) := 0]
  }
  # Cleaning up the results
  clean_emissions_rate_dt = 
    rbind(
      # low term: rate below the split
      emissions_rate_dt[term == 'net_generation_mwh',.(
        plant_id_eia, pollutant, 
        output = 'emissions_rate_low',
        net_generation_mwh_split,
        emissions_rate = estimate
      )],
      # high term: rate above the split = base rate + slope change
      wide_rate_dt[,.(
        plant_id_eia, pollutant, 
        output = 'emissions_rate_high',
        net_generation_mwh_split,
        emissions_rate = net_generation_mwh + `net_gen_over_split::TRUE:net_gen_msplit`
      )]
    ) 
  # Saving results 
  write.fst(
    emissions_rate_dt,
    here('Data/electricity-generation/plant-model-fit/emissions-rate-dt-raw.fst')
  )
  write.fst(
    clean_emissions_rate_dt,
    here('Data/electricity-generation/plant-model-fit/emissions-rate-dt.fst')
  )
  write.csv(
    clean_emissions_rate_dt,
    here('Data/electricity-generation/output/emissions-rate-median-oge.csv')
  )
  print('done.')
}

